#----------------------------------------------------------------------
#  GFDM method test - 3d hollow sphere with internal power generation
#  Same test case as in the Code_Aster validation manual
#  Author: Andrea Pavan
#  Date: 09/01/2023
#  License: GPLv3-or-later
#----------------------------------------------------------------------
using ElasticArrays;
using LinearAlgebra;
using SparseArrays;
using PyPlot;
using IterativeSolvers;
include("utils.jl");


#problem definition
#geometry: sector of a hollow sphere between the radii Ri and Re
Ri, Re = 1, 2;      #internal and external radius
alpha = 30*pi/180;      #domain angle (30 degrees expressed in radians)

#temperatures imposed on the two spherical surfaces
Ti = Te = 20;       #internal / external radius temperature

#source term and material properties
qv = 100;       #volumetric heat source
kconst = 1.0;       #thermal conductivity
rho, cp = 1.0, 1.0;     #material density, specific heat capacity

#discretization parameters
meshSize = 0.05;        #mesh size, also passed to the neighbor search
surfaceMeshSize = minSearchRadius = meshSize;
minNeighbors = 40;      #requested minimum neighbor count (passed to the neighbor search)


#read pointcloud from a SU2 file
time1 = time();
boundaryNodes = Vector{Int}(undef,0);       #indices of the boundary nodes
internalNodes = Vector{Int}(undef,0);       #indices of the internal nodes
normals = ElasticArray{Float64}(undef,3,0);     #3xN matrix containing the components [nx;ny;nz] of the unit outward normal of each boundary node ([0;0;0] for internal nodes)

#3xN matrix containing the coordinates [X;Y;Z] of each node
pointcloud = parseSU2mesh("16_3d_heat_hollow_sphere_6084.su2");
#pointcloud = parseSU2mesh("16_3d_heat_hollow_sphere_23615.su2");
#pointcloud = parseSU2mesh("16_3d_heat_hollow_sphere_70326.su2");

#discard the nodes having both x<=1e-6 and y<=1e-6
cornerPoint = findall((pointcloud[2,:].<=1e-6).*(pointcloud[1,:].<=1e-6));
pointcloud = pointcloud[:, setdiff(1:end,cornerPoint)];
N = size(pointcloud,2);
pointcloudr = sqrt.(pointcloud[1,:].^2+pointcloud[2,:].^2+pointcloud[3,:].^2);     #distance of each node from the origin

#classify every node; a node satisfying several tests (edges, corners) takes the first match
tanCone = tan(pi/2-alpha);      #slope of the conical top surface x^2+z^2 = (tanCone*y)^2
for i=1:N
    if pointcloudr[i]<=Ri+1e-6
        #internal spherical surface: outward normal points towards the centre
        #(divided by the actual radius so that the normal has exactly unit length)
        push!(boundaryNodes, i);
        append!(normals, -pointcloud[:,i]./pointcloudr[i]);
    elseif pointcloudr[i]>=Re-1e-6
        #external spherical surface: outward normal points away from the centre
        push!(boundaryNodes, i);
        append!(normals, pointcloud[:,i]./pointcloudr[i]);
    elseif pointcloud[2,i]<=1e-6
        #bottom surface (plane y=0)
        push!(boundaryNodes, i);
        append!(normals, [0,-1,0]);
    elseif pointcloud[1,i]^2+pointcloud[3,i]^2<=(tanCone*pointcloud[2,i]+1e-6)^2
        #top surface (cone around the y axis): outward normal is -grad(x^2+z^2-(tanCone*y)^2), normalized
        push!(boundaryNodes, i);
        ncone = -[pointcloud[1,i],-pointcloud[2,i]*tanCone^2,pointcloud[3,i]];
        ncone ./= norm(ncone);
        append!(normals, ncone);
    elseif pointcloud[3,i]<=1e-6
        #front surface (plane z=0)
        push!(boundaryNodes, i);
        append!(normals, [0,0,-1]);
    elseif pointcloud[3,i]>=tan(alpha)*pointcloud[1,i]-1e-6
        #rear surface (plane z=tan(alpha)*x)
        push!(boundaryNodes, i);
        append!(normals, [-sin(alpha),0,cos(alpha)]);
    else
        push!(internalNodes, i);
        append!(normals, [0,0,0]);
    end
end

println("Generated pointcloud in ", round(time()-time1,digits=2), " s");
println("Pointcloud properties:");
println("  Boundary nodes: ",length(boundaryNodes));
println("  Internal nodes: ",length(internalNodes));
println("  Memory: ",memoryUsage(pointcloud,boundaryNodes));


#boundary conditions
#Each boundary node i is given the condition g1[i]*T + g2[i]*dT/dn = g3[i]
#(see the boundary-condition rows of the least-squares systems); internal nodes keep zeros.
N = size(pointcloud,2);     #number of nodes
g1 = zeros(Float64,N);
g2 = zeros(Float64,N);
g3 = zeros(Float64,N);
for i in boundaryNodes
    #the radial tests use exactly the same tolerances as the node classification, so every
    #node flagged as lying on a spherical surface also receives that surface's Dirichlet value
    if pointcloudr[i]<=Ri+1e-6
        #internal radius: prescribed temperature Ti
        g1[i] = 1.0;
        g2[i] = 0.0;
        g3[i] = Ti;
    elseif pointcloudr[i]>=Re-1e-6
        #external radius: prescribed temperature Te
        g1[i] = 1.0;
        g2[i] = 0.0;
        g3[i] = Te;
    else
        #everywhere else: zero normal derivative
        g1[i] = 0.0;
        g2[i] = 1.0;
        g3[i] = 0.0;
    end
end

#boundary conditions plot
#=figure();
plt = scatter3D(pointcloud[1,:],pointcloud[2,:],pointcloud[3,:],c=g3);
title("Boundary conditions");
axis("equal");
colorbar(plt);
display(gcf());=#


#neighbor search (cartesian cells)
time2 = time();
N = size(pointcloud,2);     #number of nodes
#boundary neighbors of each node (indices and count); filled in when the weights are computed
boundaryNeighbors = Vector{Vector{Int}}(undef,N);
NboundaryNeighbors = zeros(Int,N);
#neighbors[i]: indices of the neighbors of node i; Nneighbors[i]: how many there are
neighbors, Nneighbors, cell = cartesianNeighborSearch(pointcloud,meshSize,minNeighbors);

println("Found neighbors in ", round(time()-time2,digits=2), " s");
println("Connectivity properties:");
let
    #findmax/findmin report the first index at which the extreme count occurs
    (mostNb, mostIdx) = findmax(Nneighbors);
    (fewestNb, fewestIdx) = findmin(Nneighbors);
    println("  Max neighbors: ",mostNb," (at index ",mostIdx,")");
    println("  Avg neighbors: ",round(sum(Nneighbors)/length(Nneighbors),digits=2));
    println("  Min neighbors: ",fewestNb," (at index ",fewestIdx,")");
end


#neighbors distances and weights
time3 = time();
#explicit boundary-node mask: testing sum(normals[:,k])!=0 would wrongly treat as internal
#any boundary node whose normal components happen to sum to zero
isBoundaryNode = falses(N);
isBoundaryNode[boundaryNodes] .= true;
P = Vector{Matrix{Float64}}(undef,N);       #relative positions of the neighbors (3 x Nneighbors[i])
r2 = Vector{Vector{Float64}}(undef,N);      #squared relative distances of the neighbors
w = Vector{Vector{Float64}}(undef,N);       #neighbors weights
wbn = Vector{Vector{Float64}}(undef,N);     #weights of the boundary-condition rows of the boundary neighbors
for i=1:N
    P[i] = Matrix{Float64}(undef,3,Nneighbors[i]);
    r2[i] = Vector{Float64}(undef,Nneighbors[i]);
    w[i] = Vector{Float64}(undef,Nneighbors[i]);
    for j=1:Nneighbors[i]
        P[i][:,j] = pointcloud[:,neighbors[i][j]]-pointcloud[:,i];
        r2[i][j] = dot(P[i][:,j],P[i][:,j]);
    end
    r2max = maximum(r2[i]);
    boundaryNeighbors[i] = Int[];
    NboundaryNeighbors[i] = 0;
    wbn[i] = Float64[];
    for j=1:Nneighbors[i]
        #gaussian weight: 1 at zero distance, exp(-6) at the farthest neighbor
        w[i][j] = exp(-6*r2[i][j]/r2max);
        if isBoundaryNode[neighbors[i][j]]
            push!(boundaryNeighbors[i], neighbors[i][j]);
            push!(wbn[i], w[i][j]/2);
            NboundaryNeighbors[i] += 1;
        end
    end
end
wpde = 2.0;       #least squares weight for the pde
wbc = 2.0;        #least squares weight for the boundary condition


#least square matrix inversion
#Around each node i the solution is approximated by a quadratic polynomial in the offsets
#(dx,dy,dz) = x - x_i with coefficients c = [u, du/dx, du/dy, du/dz, d2u/dx2/2, d2u/dy2/2,
#d2u/dz2/2, d2u/dxdy, d2u/dxdz, d2u/dydz]. C[i] maps the stacked right-hand sides of the
#weighted local least-squares system to c, so its first row yields u(x_i).

#quadratic polynomial basis evaluated at the relative offset (dx,dy,dz)
gfdmTaylorRow(dx,dy,dz) = [1, dx, dy, dz, dx^2, dy^2, dz^2, dx*dy, dx*dz, dy*dz];

#row imposing g1*u + g2*du/dn at the relative offset (dx,dy,dz), (nx,ny,nz) being the normal
function gfdmBoundaryRow(g1,g2,nx,ny,nz,dx,dy,dz)
    return [g1,
            g1*dx + g2*nx,
            g1*dy + g2*ny,
            g1*dz + g2*nz,
            g1*dx^2 + 2*g2*nx*dx,
            g1*dy^2 + 2*g2*ny*dy,
            g1*dz^2 + 2*g2*nz*dz,
            g1*dx*dy + g2*nx*dy + g2*ny*dx,
            g1*dx*dz + g2*nx*dz + g2*nz*dx,
            g1*dy*dz + g2*ny*dz + g2*nz*dy];
end

C = Vector{Matrix{Float64}}(undef,N);       #derivatives coefficients matrices
for i in internalNodes
    Ni = Nneighbors[i];
    #rows: neighbors, pde, one boundary-condition row per boundary neighbor
    V = zeros(Float64,1+Ni+NboundaryNeighbors[i],10);
    for j=1:Ni
        V[j,:] = gfdmTaylorRow(P[i][1,j],P[i][2,j],P[i][3,j]);
    end
    #pde row: (k/(rho*cp))*laplacian(u)
    V[1+Ni,:] = [0, 0, 0, 0, 2*kconst/(rho*cp), 2*kconst/(rho*cp), 2*kconst/(rho*cp), 0, 0, 0];
    for j=1:NboundaryNeighbors[i]
        bn = boundaryNeighbors[i][j];
        (nx,ny,nz) = normals[:,bn];
        #the polynomial is expanded around node i, so the boundary neighbor's position
        #must be taken relative to node i (absolute coordinates would be wrong here)
        (dxb,dyb,dzb) = pointcloud[:,bn]-pointcloud[:,i];
        V[1+Ni+j,:] = gfdmBoundaryRow(g1[bn],g2[bn],nx,ny,nz,dxb,dyb,dzb);
    end
    W = Diagonal(vcat(w[i],wpde,wbn[i]));
    (Q,R) = qr(W*V);
    #C = inv(R)*Q'*W computed with a triangular solve instead of an explicit inverse
    C[i] = UpperTriangular(R) \ (transpose(Matrix(Q))*W);
end
for i in boundaryNodes
    Ni = Nneighbors[i];
    #rows: neighbors, pde, boundary condition of node i itself
    V = zeros(Float64,2+Ni,10);
    for j=1:Ni
        V[j,:] = gfdmTaylorRow(P[i][1,j],P[i][2,j],P[i][3,j]);
    end
    V[1+Ni,:] = [0, 0, 0, 0, 2*kconst/(rho*cp), 2*kconst/(rho*cp), 2*kconst/(rho*cp), 0, 0, 0];
    #the node's own boundary condition is evaluated at zero offset
    (nx,ny,nz) = normals[:,i];
    V[2+Ni,:] = gfdmBoundaryRow(g1[i],g2[i],nx,ny,nz,0.0,0.0,0.0);
    W = Diagonal(vcat(w[i],wpde,wbc));
    (Q,R) = qr(W*V);
    C[i] = UpperTriangular(R) \ (transpose(Matrix(Q))*W);
end
println("Inverted least-squares matrices in ", round(time()-time3,digits=2), " s");


#matrix assembly
#Row i of M encodes u_i - sum_j C[i][1,j]*u_(neighbors[i][j]); the entries are gathered in
#coordinate form (rows, cols, vals) and duplicate positions are summed by `sparse`.
time4 = time();
rows = Int[];
cols = Int[];
vals = Float64[];
for i=1:N
    nb = Nneighbors[i];
    #diagonal entry first, then the neighbor entries of row i
    append!(rows, fill(i, nb+1));
    append!(cols, vcat(i, neighbors[i][1:nb]));
    append!(vals, vcat(1.0, -C[i][1,1:nb]));
end
M = sparse(rows,cols,vals,N,N);
println("Completed matrix assembly in ", round(time()-time4,digits=2), " s");


#linear system solution
time5 = time();
b = zeros(N);       #rhs vector
#internal nodes: heat-source term of the pde row, then the data g3 of every boundary neighbor,
#accumulated left to right
for i in internalNodes
    nb = Nneighbors[i];
    Ci = C[i];
    b[i] = foldl(+, (Ci[1,1+nb+j]*g3[boundaryNeighbors[i][j]] for j=1:NboundaryNeighbors[i]);
                 init = Ci[1,1+nb]*(-qv)/(rho*cp));
end
#boundary nodes: heat-source term plus the node's own boundary data
for i in boundaryNodes
    nb = Nneighbors[i];
    b[i] = C[i][1,1+nb]*(-qv)/(rho*cp) + C[i][1,2+nb]*g3[i];
end
#direct alternative: u = M\b;
u = bicgstabl(M,b);
println("Linear system solved in ", round(time()-time5,digits=2), " s");


#solution plot: computed temperature at every node
figure();
plt = scatter3D((pointcloud[k,:] for k=1:3)...,c=u,cmap="jet");
title("Numerical solution");
axis("equal");
colorbar(plt);
display(gcf());

#validation plot
#every node is compared with the analytical profile; the commented line restricts the
#comparison to the nodes with y<=1e-6 and z<=1e-6
#plotIdx = findall((pointcloud[2,:].<=1e-6).*(pointcloud[3,:].<=1e-6));
plotIdx = Vector(1:N);
rexact = LinRange(Ri,Re,100);       #radii at which the analytical solution is drawn
#analytical steady-state temperature in a hollow sphere with uniform heat source qv:
#T(r) = -qv*r^2/(6*kconst) + A/r + B, with A and B fixed by T(Ri)=Ti and T(Re)=Te
#(the (Te-Ti) term vanishes when Ti == Te); accepts a scalar radius or an array of radii
uexact(r) = @. Ti + (Te-Ti)*(1/Ri-1/r)/(1/Ri-1/Re) + (qv/(6*kconst))*(((Re^2-Ri^2)*(1/Ri-1/r))/(1/Ri-1/Re)-(r^2-Ri^2));
figure();
#numerical solution vs analytical profile, both as functions of the radius
plot(pointcloudr[plotIdx],u[plotIdx],"k.",label="GFDM");
plot(rexact,uexact(rexact),"r-",linewidth=1.0,label="Analytical");
title("Reference temperature");
legend(loc="upper left");
xlabel("Radius r");
ylabel("Temperature T");
display(gcf());

#numerical info: pointwise absolute error against the analytical solution
#(computed once and reused for the error plot below)
err = abs.(u-uexact(pointcloudr));
maxerr = maximum(err);
rmse = sqrt(sum(abs2,err)/N);
println("umax = ", maximum(u));
println("maxerr = ", maxerr);
println("rmse = ", rmse);
#println("cond(M,1) = ",cond(M,1));

#max error plot
figure();
plt = scatter3D(pointcloud[1,:],pointcloud[2,:],pointcloud[3,:],c=err,cmap="jet");
title("Numerical error");
axis("equal");
colorbar(plt);
display(gcf());
